/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// Load-balancing determination for the normal grid

#include "hilbert.h"
#include "partition.h"
#include "mpi_util.h"

/** Calculates the domain decomposition of the grid, using a Hilbert
    curve of order level. It first calculates the number of grid cells
    in each Hilbert cell. This is the work vector, which is partitioned
    into mpi_size() blocks. It sets the domain_ vector with the Hilbert
    codes of the partition boundaries (including the end point), and
    the domainstart_ vector with the starting cell of each partition
    followed by a final entry equal to the total number of counted
    cells.

    \param structure Depth-first refinement flags of the octree, as
           consumed by calculate_work_recursive(). The first entry is
           the root cell and is true if the root is refined.
    \param level     Order of the Hilbert curve. The curve has 8^level
           cells, which is the length of the work vector. */
template < typename cell_data_type>
void
mcrx::adaptive_grid<cell_data_type>::
domain_decomposition(const std::vector<bool>& structure, const int level)
{
  using namespace hilbert;

  domain_.clear();
  domainstart_.clear();

  // Number of cells in a level N Hilbert curve: 8^N. This is computed in
  // size_t so that large levels cannot overflow a signed int (UB).
  size_t nhc = 1;
  for (int i = 0; i < level; ++i) nhc *= 8;
  std::vector<int> work(nhc, 0);

  std::cout << "Determining domain decomposition using a level "
            << level << " Hilbert curve" << std::endl;

  // Fill the work vector with the number of grid cells in each Hilbert
  // cell. The first structure entry is the root cell. Only if it is
  // refined is there anything to distribute; with a single cell the
  // whole exercise is moot. An empty structure vector is guarded against
  // instead of being dereferenced.
  std::vector<bool>::const_iterator s = structure.begin();
  if (s != structure.end() && *s++)
    calculate_work_recursive(s, structure.end(), cell_code(), work, level);

  // calculate domain partitions
  std::vector<size_t> parts = partition(mpi_size(), work);
  // the partition at the end point is not put in there by the algorithm
  parts.push_back(work.size());

  // change the partition boundaries to Hilbert codes
  for (size_t i = 0; i < parts.size(); ++i)
    domain_.push_back(cell_code(parts[i], level));

  // Number of grid cells in each domain. The sum uses iterators rather
  // than &work[parts[i+1]]. The last boundary equals work.size(), and
  // indexing a vector at its size is out of range.
  std::vector<size_t> domaincells;
  if (is_mpi_master())
    std::cout << "Partitions are: \n";
  for (int i = 0; i < mpi_size(); ++i) {
    const size_t this_work =
      std::accumulate(work.begin() + parts[i], work.begin() + parts[i + 1],
                      size_t(0));
    domaincells.push_back(this_work);

    if (is_mpi_master())
      std::cout << i << "\t[" << domain_[i] << '-' << domain_[i + 1] << "[\t"
                << parts[i + 1] - parts[i] << " hilbert cells\t"
                << this_work << " grid cells\n";
  }

  // Ratio of the most- to the least-loaded domain. It is skipped when
  // there are no domains, because max/min_element would then return end().
  // If the least-loaded domain is empty, the floating-point division
  // reports an infinite ratio (on IEEE-754 platforms).
  if (is_mpi_master() && !domaincells.empty())
    std::cout << "Load imbalance (max/min work): "
              << 1.0 * *std::max_element(domaincells.begin(), domaincells.end()) /
                 *std::min_element(domaincells.begin(), domaincells.end())
              << std::endl;

  // the start of the domains is the partial_sum of the number of cells.
  domainstart_.push_back(0);
  std::partial_sum(domaincells.begin(), domaincells.end(),
                   back_inserter(domainstart_));
}


/** Virtual depth-first traversal of the octree that tallies leaf cells
    into the work vector without instantiating any cell objects.

    \param structure    Iterator into the refinement-flag vector,
           positioned at the flag of the first child of the cell whose
           Hilbert code is \p code. On return it has been advanced past
           every flag belonging to that cell's subtree.
    \param e            End of the refinement-flag vector. It is only
           checked in a debug assertion and forwarded to recursive calls.
    \param code         Hilbert code of the parent cell whose eight
           children are visited.
    \param work         Work vector indexed by Hilbert code at level
           \p decomp_level. Each leaf adds 1 to its entry.
    \param decomp_level Hilbert-curve level of the decomposition. Leaf
           codes finer than this are truncated to it. Coarser ones are
           passed through extend_high(), presumably mapping them onto
           their last descendant at that level (confirm in hilbert.h). */
template < typename cell_data_type>
void
mcrx::adaptive_grid<cell_data_type>::
calculate_work_recursive(std::vector<bool>::const_iterator& structure,
			 const std::vector<bool>::const_iterator& e,
			 const hilbert::cell_code code,
			 std::vector<int>& work,
			 const int decomp_level) const
{
  using namespace hilbert;

  // Children are stored in Hilbert-code order, so visiting octants
  // 0..7 in sequence consumes the structure flags in the right order.
  for (uint8_t child = 0; child < 8; ++child) {
    cell_code child_code(code);
    child_code.add_right(child);

    assert(structure != e);
    const bool is_refined = *structure;
    ++structure;

    if (is_refined) {
      // Descend. The recursive call consumes this subtree's flags.
      calculate_work_recursive(structure, e, child_code, work, decomp_level);
      continue;
    }

    // Leaf: map its code onto the decomposition level and count it.
    if (child_code.level() < decomp_level)
      child_code.extend_high(decomp_level);
    else if (child_code.level() > decomp_level)
      child_code.truncate(decomp_level);

    assert(child_code.c_ < work.size());
    ++work[child_code.c_];
  }
}


	
